{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "29ef44f6-876d-4c92-bfa2-885dac873ad9",
   "metadata": {},
   "source": [
    "# Lesson 10: Image Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa61c20e",
   "metadata": {},
   "source": [
    "- In the classroom, the libraries are already installed for you.\n",
    "- To run this code on your own machine, install the following packages:\n",
    "\n",
    "```\n",
    "!pip install transformers\n",
    "!pip install torch\n",
    "!pip install pillow requests\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "190858c7-792d-4695-ba6d-b502c1d25b03",
   "metadata": {},
   "source": [
    "- The following code sets the `transformers` logging level to errors only, which suppresses warning messages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1faec1a-dcda-4b44-84b8-1631f0bb464e",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from transformers.utils import logging\n",
    "logging.set_verbosity_error()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "847fac54",
   "metadata": {},
   "source": [
    "- Load the model and the processor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67da6b12-64f8-44b1-afd0-717e41346c84",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from transformers import BlipForImageTextRetrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "417991ee-dbf7-48fe-83a6-b6d5700632be",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "model = BlipForImageTextRetrieval.from_pretrained(\n",
    "    \"./models/Salesforce/blip-itm-base-coco\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f80ab891",
   "metadata": {},
   "source": [
    "More info about [Salesforce/blip-itm-base-coco](https://huggingface.co/Salesforce/blip-itm-base-coco)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "478a6f97-a290-42cd-9f2c-284ea7e41d49",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50faf354-91e4-4739-8e2d-7547504e4c92",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "processor = AutoProcessor.from_pretrained(\n",
    "    \"./models/Salesforce/blip-itm-base-coco\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e72f24c-86b8-48d3-af80-07434ea43522",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e17973f-f89a-40a4-9e4a-b46f46d4dfba",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d711df0-1f01-42e5-b397-8ffcb6ef4c12",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "raw_image =  Image.open(\n",
    "    requests.get(img_url, stream=True).raw).convert('RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc0a6f43-1607-4364-a7d1-c9799c1a19af",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "raw_image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38db1166-f8c1-4f06-ad67-7083492b6ceb",
   "metadata": {},
   "source": [
    "### Test whether the image matches the text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1edee591-c891-4b4b-91ee-1e9f54642e97",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "text = \"an image of a woman and a dog on the beach\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b41ec7-70d7-49c2-909a-e600ab7e76ca",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "inputs = processor(images=raw_image,\n",
    "                   text=text,\n",
    "                   return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f818f97-3b9a-4024-8cab-31b615d32632",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db51ff2e-74c2-4ea1-b271-48d8f15e7fd6",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "# The first element of the model output holds the raw image-text matching (ITM) scores\n",
    "itm_scores = model(**inputs)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c962d20d-6091-4bba-9cc9-afa9a3bef79b",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "itm_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7615e88d-a6c1-4363-bcab-0f11e93d20af",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc0a433f-d1f1-45c7-b776-1472b8a7cbcd",
   "metadata": {},
   "source": [
    "- Apply a softmax to the raw scores to convert them into probabilities."
   ]
  },
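  {
   "cell_type": "markdown",
   "id": "3f9a1c7e",
   "metadata": {},
   "source": [
    "As a short worked equation (a sketch, not taken from the model documentation): suppose the model returns two raw scores for one image-text pair, written here as $z_0$ and $z_1$. This notebook reads position 1 as the match class; treating position 0 as the no-match class is an assumption made for illustration. The softmax over the two scores then gives\n",
    "\n",
    "$$\n",
    "P(\\text{match}) = \\frac{e^{z_1}}{e^{z_0} + e^{z_1}}, \\qquad P(\\text{no match}) = \\frac{e^{z_0}}{e^{z_0} + e^{z_1}} = 1 - P(\\text{match}).\n",
    "$$\n",
    "\n",
    "The symbols $z_0$ and $z_1$ are notation introduced here; they correspond to `itm_scores[0][0]` and `itm_scores[0][1]`. Only the difference $z_1 - z_0$ matters: equal scores give a probability of 0.5, and a larger $z_1$ pushes the match probability toward 1."
   ]
  },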
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0749cbf-be41-4b47-92ac-a0f6515ff31e",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "itm_score = torch.nn.functional.softmax(\n",
    "    itm_scores, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adc3f239-f95e-4451-9a5b-b8b83011eeb4",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "itm_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f000eb6-2b6a-48c5-8481-a3bbdd5b3c5d",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "print(f\"\"\"\\\n",
    "The image and text are matched \\\n",
    "with a probability of {itm_score[0][1]:.4f}\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "416bc237",
   "metadata": {},
   "source": [
    "### Try it yourself!\n",
    "- Try this model with your own images and texts!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02012ee6-68f8-4b18-ba64-9596cdcd9f9d",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
